#!/usr/bin/env python3
import argparse, os, json, time, subprocess, queue, threading
import pandas as pd
from pathlib import Path

# Default path of the source catalogue; used as the default of the --kids option.
KIDS_DEFAULT = "data/KiDS_DR4.1_ugriZYJHKs_SOM_gold_WL_cat.fits"
# Comma-separated bin edges. main() passes these same values to
# scripts/prestack_kids.py as --rg-bins, --mstar-bins and --b-bins-arcsec
# respectively.
RG_EDGES = "1.5,3,5,8,12"
MS_EDGES = "10.2,10.5,10.8,11.1"
B_EDGES  = "10,15,22,32,46,66,95,137,198,285,410,592,855,1236,1787,2583"

def norm_key(s):
    """Return *s* with '[', ']' and spaces removed and every ',' turned into '-'.

    Used to turn a bin label into a fragment that is safe to embed in a
    file name.
    """
    # One pass: map ',' -> '-' and delete the bracket and space characters.
    table = str.maketrans(",", "-", "[] ")
    return s.translate(table)

def split_bins(in_csv, out_dir):
    """Split a catalogue CSV into one CSV per (RG bin, Mstar bin) combination.

    The RG bin column is ``RG_bin`` (or ``R_G_bin``) and the mass bin column
    is ``Mstar_bin`` (or ``M_bin``). Bin labels are compared as strings.
    Combinations with no rows are skipped.

    Args:
        in_csv: Path of the input CSV.
        out_dir: Directory receiving the per-bin CSV files and ``index.json``;
            created if missing.

    Returns:
        A list of dicts with keys ``rg_bin``, ``ms_bin``, ``rows`` and
        ``path``, one per file written. The same list is saved as
        ``index.json`` in *out_dir*.

    Raises:
        SystemExit: If either bin column is missing from the input.
    """
    df = pd.read_csv(in_csv)
    rg = next((c for c in ("RG_bin", "R_G_bin") if c in df.columns), None)
    ms = next((c for c in ("Mstar_bin", "M_bin") if c in df.columns), None)
    if not rg or not ms:
        raise SystemExit(f"Missing RG_bin/Mstar_bin in {in_csv}")

    out_root = Path(out_dir)
    out_root.mkdir(parents=True, exist_ok=True)

    # Convert the label columns to strings once instead of once per bin pair.
    rg_labels = df[rg].astype(str)
    ms_labels = df[ms].astype(str)
    rg_values = sorted(rg_labels.unique())
    ms_values = sorted(ms_labels.unique())

    idx = []
    for r in rg_values:
        in_rg = rg_labels == r  # invariant across the inner loop
        for m in ms_values:
            sub = df[in_rg & (ms_labels == m)]
            if sub.empty:
                continue
            fn = out_root / f"lenses_rg{norm_key(r)}_ms{norm_key(m)}.csv"
            sub.to_csv(fn, index=False)
            idx.append({"rg_bin": r, "ms_bin": m, "rows": int(len(sub)), "path": str(fn)})
            print(f"[bin] {r} × {m}: {len(sub):,} -> {fn}")

    (out_root / "index.json").write_text(json.dumps(idx, indent=2))
    print(f"[OK] wrote {len(idx)} bin files -> {out_dir}/index.json")
    return idx

def concat(parts, metas, target):
    """Merge the per-bin stack and meta CSVs into two combined files under data/.

    Paths in *parts* / *metas* that do not exist are skipped. The outputs are
    ``data/prestacked_stacks_<lens|rand>.csv`` and
    ``data/prestacked_meta_<lens|rand>.csv``, where the suffix is ``lens``
    when *target* is ``"lenses"`` and ``rand`` otherwise. An output is only
    written when at least one of its inputs exists.
    """
    import pandas as pd

    suffix = "lens" if target == "lenses" else "rand"
    out_all = f"data/prestacked_stacks_{suffix}.csv"
    meta_all = f"data/prestacked_meta_{suffix}.csv"

    stack_frames = []
    for part_path in parts:
        if os.path.exists(part_path):
            stack_frames.append(pd.read_csv(part_path))

    meta_frames = []
    for meta_path in metas:
        if os.path.exists(meta_path):
            meta_frames.append(pd.read_csv(meta_path))

    if stack_frames:
        combined = pd.concat(stack_frames, ignore_index=True)
        combined.to_csv(out_all, index=False)
    if meta_frames:
        combined_meta = pd.concat(meta_frames, ignore_index=True)
        combined_meta.to_csv(meta_all, index=False)

    n_rows = sum(len(frame) for frame in stack_frames)
    print(f"[OK] concatenated -> {out_all} ({n_rows} rows), {meta_all}")

def run_bin(cmd, log):
    """Run *cmd* as a subprocess and return its exit code.

    The file *log* is truncated and receives both stdout and stderr of the
    child process.
    """
    with open(log, "w") as log_file:
        finished = subprocess.run(
            cmd,
            stdout=log_file,
            stderr=subprocess.STDOUT,
        )
    return finished.returncode

def worker(q, results):
    """Drain jobs from *q*, run each one, and append the finished job to *results*.

    Each job is a dict providing ``"cmd"`` and ``"log"``. The worker adds
    ``"rc"`` (exit code) and ``"dt"`` (wall-clock seconds) to the job before
    appending it to *results*. The function returns once the queue is empty.

    If launching the job raises (for example the executable or the log
    directory does not exist), the job is still recorded, with ``rc`` set to
    -1 and the exception text under ``"error"``. Without this the thread
    would die and the job would vanish from *results*, so the caller could
    not tell that it had failed.
    """
    while True:
        try:
            job = q.get_nowait()
        except queue.Empty:
            return
        t0 = time.time()
        try:
            try:
                rc = run_bin(job["cmd"], job["log"])
            except Exception as exc:  # launch failure: record it as a failed job
                rc = -1
                job["error"] = f"{type(exc).__name__}: {exc}"
                print(f"[ERR] {job.get('key')}: {job['error']}")
            job["rc"] = rc
            job["dt"] = time.time() - t0
            results.append(job)
        finally:
            # Always balance the get() so a q.join() elsewhere cannot hang.
            q.task_done()

def save_prog(path, state):
    """Write *state* as indented JSON to *path*.

    The data is written to a sibling ``<name>.tmp`` file first and then moved
    over *path* with ``os.replace``, so an interruption mid-write does not
    leave a truncated progress file behind.
    """
    target = Path(path)
    tmp = target.with_name(target.name + ".tmp")
    tmp.write_text(json.dumps(state, indent=2))
    os.replace(tmp, target)

def main():
    """Run scripts/prestack_kids.py once per bin file and merge the results.

    Steps:
      1. Parse the command line (--target is required).
      2. Load ``work/bins/<target>/index.json`` if present, otherwise create
         the per-bin files with split_bins().
      3. Build one subprocess job per bin. With --resume, bins listed as
         completed in the progress file whose output part still exists are
         skipped.
      4. Run the jobs on a pool of worker threads, rewriting the progress
         JSON roughly every two seconds.
      5. Exit with status 1 if any job failed or never reported a return
         code; otherwise concatenate the per-bin outputs.

    Raises:
        SystemExit: With code 1 when at least one bin did not finish
            successfully.
    """
    import sys  # local import: only needed for sys.executable below

    ap = argparse.ArgumentParser()
    ap.add_argument("--target", choices=["lenses", "randoms"], required=True)
    ap.add_argument("--kids", default=KIDS_DEFAULT)
    ap.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 8) // 3))
    ap.add_argument("--use-m-corr", action="store_true")
    ap.add_argument("--resume", action="store_true")
    args = ap.parse_args()

    target = args.target
    in_csv = "data/lenses.csv" if target == "lenses" else "data/lenses_random.csv"
    base = f"work/bins/{target}"
    logs = "logs"
    Path(logs).mkdir(exist_ok=True)
    prog = f"data/prestack_progress_{target}.json"

    idx_path = Path(base, "index.json")
    bins = json.loads(idx_path.read_text()) if idx_path.exists() else split_bins(in_csv, base)

    # Keys recorded as completed by a previous run (only consulted with --resume).
    previous = set()
    if args.resume and os.path.exists(prog):
        try:
            previous = set(json.loads(Path(prog).read_text()).get("completed", []))
        except (OSError, ValueError, AttributeError, TypeError) as exc:
            # Best effort: an unreadable progress file just means "redo everything".
            print(f"[warn] ignoring unreadable progress file {prog}: {exc}")

    # Run the child with the interpreter running this script rather than
    # whatever "python" happens to resolve to on PATH.
    python = sys.executable or "python"

    # Only bins that are actually skipped count as completed up front, so
    # stale keys from the progress file are neither double-counted nor
    # carried over when their re-run fails.
    completed = set()
    parts = []
    metas = []
    jobs = []
    for b in bins:
        key = f"{b['rg_bin']}|{b['ms_bin']}"
        stem = Path(b["path"]).stem
        part = f"{base}/prestack_{target}_{stem}.csv"
        meta = f"{base}/meta_{target}_{stem}.csv"
        log = f"{logs}/prestack_{target}_{stem}.log"
        parts.append(part)
        metas.append(meta)
        if key in previous and os.path.exists(part):
            completed.add(key)
            continue
        cmd = [
            python, "scripts/prestack_kids.py",
            "--kids", args.kids,
            "--lenses", b["path"],
            "--out", part,
            "--out-meta", meta,
            "--rg-bins", RG_EDGES,
            "--mstar-bins", MS_EDGES,
            "--b-bins-arcsec", B_EDGES,
            "--min-zsep", "0.1",
        ]
        if args.use_m_corr:
            cmd.append("--use-m-corr")
        jobs.append({"key": key, "cmd": cmd, "log": log, "rows": b["rows"]})

    total = len(bins)
    started = time.time()
    q = queue.Queue()
    for job in jobs:
        q.put(job)
    results = []
    n_workers = max(1, int(args.workers))
    print(f"[run] {target}: bins={total} plan={q.qsize()} workers={n_workers}")
    threads = [
        threading.Thread(target=worker, args=(q, results), daemon=True)
        for _ in range(n_workers)
    ]
    for t in threads:
        t.start()

    while any(t.is_alive() for t in threads):
        time.sleep(2)
        # Snapshot: worker threads keep appending to `results` concurrently.
        finished = [r for r in list(results) if "rc" in r]
        done = len(completed) + len(finished)
        dts = [r["dt"] for r in finished if r.get("dt", 0) > 0]
        eta = (sum(dts) / len(dts) * (total - done)) if dts else None
        save_prog(prog, {
            "target": target,
            "total": total,
            "done": done,
            "completed": sorted(completed) + [r["key"] for r in finished if r["rc"] == 0],
            "elapsed_sec": round(time.time() - started, 1),
            "eta_sec": round(eta, 1) if eta is not None else None,
        })

    # The worker mutates the job dicts in place, so a job without "rc" never
    # finished (for example its thread died). Treat that as a failure too.
    errs = [j for j in jobs if j.get("rc") != 0]
    for j in jobs:
        if j.get("rc") == 0:
            completed.add(j["key"])
    save_prog(prog, {
        "target": target,
        "total": total,
        "done": len(completed),
        "completed": sorted(completed),
        "elapsed_sec": round(time.time() - started, 1),
        "errors": [{"key": e["key"], "log": e["log"]} for e in errs],
    })
    if errs:
        print(f"[ERR] {len(errs)} bin(s) failed; see logs/")
        raise SystemExit(1)
    concat(parts, metas, target)
    print("[OK] orchestrator finished.")
# Run the orchestrator only when this file is executed as a script.
if __name__ == "__main__":
    main()
